# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Dataloaders."""

import os
import logging
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron.core import mpu
from megatron.training import get_args
from megatron.core.utils import log_single_rank
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler
from megatron.legacy.data.data_samplers import MegatronPretrainingRandomSampler

logger = logging.getLogger(__name__)

def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    if int(os.getenv("USE_EPX", 0)): # Fault tolerance sampler
        from epx import EpxSampler
        import megatron.core.parallel_state as parallel_state
        lcp = parallel_state.get_epx_data_parallel_lcp()
        log_single_rank(logger, logging.INFO, f"Use EpxSampler")

        batch_sampler = EpxSampler(
            total_samples=len(dataset),
            lcp=lcp,
            micro_batch_size=args.micro_batch_size)
    else: # Megatron sampler
        if args.dataloader_type == 'single':
            batch_sampler = MegatronPretrainingSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=args.micro_batch_size,
                data_parallel_rank=mpu.get_data_parallel_rank(),
                data_parallel_size=mpu.get_data_parallel_world_size())
        elif args.dataloader_type == 'cyclic':
            batch_sampler = MegatronPretrainingRandomSampler(
                dataset,
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=args.micro_batch_size,
                data_parallel_rank=mpu.get_data_parallel_rank(),
                data_parallel_size=mpu.get_data_parallel_world_size(),
                data_sharding=args.data_sharding)
        elif args.dataloader_type == "external":
            # External dataloaders are passed through. User is expected to provide a
            # torch-compatible dataloader and define samplers, if needed.
            return dataset
        else:
            raise Exception('{} dataloader type is not supported.'.format(
                    args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       persistent_workers=True if args.num_workers > 0 else False,
                                       )


import sys
for k in sys.modules:
    if k.startswith('megatron'):
        for target in ['build_pretraining_data_loader']:
            if getattr(sys.modules[k], target, None):
                setattr(sys.modules[k], target, build_pretraining_data_loader)
